//file:src/system/lpc/share.c
//author:jiangxinpeng
//time:2021.3.11
//copyright:(C) 2020-2050 by jiangxinpeng, All rights reserved.

#include <os/memcache.h>
#include <os/memspace.h>
#include <os/spinlock.h>
#include <os/schedule.h>
#include <os/share.h>
#include <os/debug.h>
#include <os/semaphore.h>
#include <os/safety.h>
#include <os/fifo.h>
#include <sys/lpc.h>
#include <sys/ipc.h>
#include <lib/errno.h>
#include <lib/string.h>

//table of SHARE_MEM_NUM share mem slots, allocated by ShareMemInit;
//a slot whose name[0] is '\0' is free
share_mem_t *share_mem_table;
//semaphore with initial count 1, used as a lock around table lookups,
//slot allocation and slot release
DEFINE_SEMAPHORE(share_mem_mutex, 1);

//init share mem table
void ShareMemInit()
{
    share_mem_table = (share_mem_t *)KMemAlloc(sizeof(share_mem_t) * SHARE_MEM_NUM);
    if (share_mem_table != NULL)
    {
        for (int i = 0; i < SHARE_MEM_NUM; i++)
        {
            share_mem_table[i].id = i;
            share_mem_table[i].flags = 0;
            share_mem_table[i].count = 0;
            share_mem_table[i].pybase = 0;
            SemaphoreInit(&share_mem_table[i].mutex, 0);
            AtomicSet(&share_mem_table[i].ref, 0);
            memset(share_mem_table[i].name, 0, SHARE_MEM_NAME_LEN);
        }
    }
    KPrint("[sharemem] init share memory ok!\n");
}

//claim a free slot in share_mem_table for a new share mem object
//name: object name, copied into the slot (truncated if it does not fit,
//      the stored name is always NUL-terminated)
//size: requested size in bytes, rounded up to whole pages; 0 becomes 1 byte
//return: the claimed slot, or NULL when every slot is in use
//expected to be called with share_mem_mutex held (as ShareMemGet does)
share_mem_t *ShareMemAlloc(char *name, uint32_t size)
{
    share_mem_t *share;

    for (int i = 0; i < SHARE_MEM_NUM; i++)
    {
        share = share_mem_table + i;
        if (share->name[0] == '\0') //an empty name marks a free slot
        {
            uint32_t n = 0;

            if (!size)
            {
                size = 1;
            }
            size = PageAlign(size);
            share->count = size / PAGE_SIZE;
            share->flags = 0;
            share->pybase = 0;
            //bounded copy: an unbounded strcpy would overflow name[] for
            //long names
            while (n + 1 < sizeof(share->name) && name[n] != '\0')
            {
                share->name[n] = name[n];
                n++;
            }
            share->name[n] = '\0';
            return share;
        }
    }
    return NULL;
}

//free a share mem
//release the physical pages of share (unless it is SHARE_MEM_PRIVATE, i.e.
//its pages were borrowed from an existing mapping) and mark the slot free
//return: 0 on success, -1 if FreePage fails (the slot is then kept so the
//        pages stay reachable)
int ShareMemFree(share_mem_t *share)
{
    //parenthesized: `!flags & PRIVATE` would test (!flags) & PRIVATE
    if (share->pybase && !(share->flags & SHARE_MEM_PRIVATE))
    {
        if (FreePage(share->pybase) != 0)
        {
            return -1;
        }
    }
    share->pybase = 0;
    share->flags = 0;
    memset(share->name, 0, SHARE_MEM_NAME_LEN);
    return 0;
}

//map share mem to vmm space
void *ShareMemMap(int shmid, void *shmaddr, int shmflags)
{
    task_t *cur = cur_task;
    uint64_t addr;
    share_mem_t *share;
    SemaphoreDown(&share_mem_mutex);
    share = ShareMemFindById(shmid); //find share object
    SemaphoreUp(&share_mem_mutex);
    if (!share)
    {
        KPrint("share mem %d not found!\n", shmid);
        return (void *)NULL;
    }
    uint64_t len = share->count * PAGE_SIZE;
    //auto switch unmap virtual address
    if (!shmaddr)
    {
        addr = MemSpaceGetUnMap(cur->vmm, share->count * PAGE_SIZE);
        if (!addr)
            return NULL;
        if (addr < cur->vmm->map_start || addr + len >= cur->vmm->map_end)
            return NULL;
        if (MemSpaceFindIntersection(cur->vmm, (uintptr_t)addr, addr + len))
            return NULL;
        //no alloc pypage
        if (!share->pybase)
        {
            share->pybase = (uintptr_t)AllocUserPage(share->count);
            if (share->pybase)
                return NULL;
        }
        uint64_t flags = MEM_SPACE_MAP_FIXED | MEM_SPACE_MAP_SHARE;
        if (shmflags & IPC_REMAP)
        {
            flags |= MEM_SPACE_MAP_REMAP;
        }
        //map pypage to a unmap viraddr
        shmaddr = MemSpaceMap(cur->vmm, addr, share->pybase, share->count * PAGE_SIZE, PROTE_USER | PROTE_WRITE, flags);
    }
    else
    {
        uint64_t vaddr;
        if (shmflags & IPC_RND)
        {
            vaddr = (uintptr_t)shmaddr & PAGEALIGN_MASK;
        }
        else
        {
            vaddr = (uintptr_t)shmaddr;
        }
        if (!share->pybase) //get pybase
        {
            share->pybase = Vbase2Pybase(vaddr);
            if (!share->pybase)
                return NULL;
            share->flags |= SHARE_MEM_PRIVATE;
        }
        shmaddr = (void *)(uintptr_t)vaddr;
    }
    if (shmaddr)
        AtomicInc(&share->ref);
    return shmaddr;
}

//unmap a share mem previously mapped by ShareMemMap
//shmaddr: virtual address inside the mapping (IPC_RND rounds it with
//         PAGEALIGN_MASK first)
//return:  0 on success (the object's ref count is decremented), -1 when no
//         mem space covers the address, or MemSpaceUnmap's error code
int ShareMemUnmap(const void *shmaddr, int shmflag)
{
    task_t *cur = cur_task;
    mem_space_t *space;
    share_mem_t *share;
    uint64_t vaddr;
    uint64_t pybase;
    int ret = 0; //private objects have nothing to unmap: success

    if (!shmaddr)
        return -1;
    if (shmflag & IPC_RND)
        vaddr = (uintptr_t)shmaddr & PAGEALIGN_MASK;
    else
        vaddr = (uintptr_t)shmaddr;
    space = MemSpaceFind(cur->vmm, (uintptr_t)vaddr);
    if (!space)
    {
        //print the address as a number: %s would read user memory as a string
        KPrint("no found share mem space %x \n", vaddr);
        return -1;
    }
    //find share mem object by pybase
    //NOTE(review): ShareMemFindByAddr takes uint32_t, so physical addresses
    //above 4GB are truncated here; widening it needs a header change
    pybase = Vbase2Pybase((uintptr_t)vaddr);
    SemaphoreDown(&share_mem_mutex);
    share = ShareMemFindByAddr(pybase);
    SemaphoreUp(&share_mem_mutex);

    //private objects were never mapped by ShareMemMap, so only unmap the
    //space for regular objects (or when the object cannot be identified)
    if (!share || !(share->flags & SHARE_MEM_PRIVATE))
    {
        ret = MemSpaceUnmap(cur->vmm, space->start, space->end - space->start);
    }
    if (!ret)
    {
        if (share)
            AtomicDec(&share->ref);
        else
            KPrint("do unmap at pybase %x virbase %x failed!\n", pybase, shmaddr);
    }
    return ret;
}

//search share mem by name
//return the in-use slot whose name equals name, or NULL if there is none
share_mem_t *ShareMemFindByName(char *name)
{
    share_mem_t *slot = share_mem_table;
    share_mem_t *end = share_mem_table + SHARE_MEM_NUM;

    for (; slot < end; slot++)
    {
        //an empty name marks a free slot; never match those
        if (slot->name[0] == '\0')
            continue;
        if (strcmp(slot->name, name) == 0)
            return slot;
    }
    return NULL;
}

//search share mem by id
//return the in-use slot with the given id, or NULL if it is free or unknown
share_mem_t *ShareMemFindById(uint32_t id)
{
    int idx = 0;

    while (idx < SHARE_MEM_NUM)
    {
        share_mem_t *slot = &share_mem_table[idx];

        if (*slot->name != '\0' && slot->id == id)
            return slot;
        idx++;
    }
    return NULL;
}

//search share mem by the physical base address of its pages
//return the in-use slot whose pybase equals pybase, or NULL if there is none
//NOTE(review): pybase is uint32_t, so 64-bit physical addresses passed by
//callers are truncated before the comparison; widening needs a header change
share_mem_t *ShareMemFindByAddr(uint32_t pybase)
{
    int idx = 0;

    while (idx < SHARE_MEM_NUM)
    {
        share_mem_t *slot = &share_mem_table[idx];

        if (*slot->name != '\0' && slot->pybase == pybase)
            return slot;
        idx++;
    }
    return NULL;
}

int ShareMemInc(int id)
{
    share_mem_t *sharemem;
    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindById(id);
    if (sharemem)
    {
        AtomicInc(&sharemem->ref);
        SemaphoreUp(&share_mem_mutex);
        return 0;
    }
    SemaphoreUp(&share_mem_mutex);
    return -1;
}

int ShareMemDec(int id)
{
    share_mem_t *sharemem;
    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindById(id);
    if (sharemem)
    {
        AtomicDec(&sharemem->ref);
        SemaphoreUp(&share_mem_mutex);
        return 0;
    }
    SemaphoreUp(&share_mem_mutex);
    return -1;
}

//get (and optionally create) a share mem object by name
//flags: IPC_CREATE creates the object if it does not exist; together with
//       IPC_EXCL it fails if the object already exists
//return: the object id (>= 0), or a negative errno:
//        -EINVAL  bad name or size too large
//        -EEXIST  IPC_CREATE|IPC_EXCL and the name is taken
//        -ENOENT  no IPC_CREATE and the name does not exist
//        -ENOMEM  IPC_CREATE but every slot is in use
int ShareMemGet(char *name, uint32_t size, uint32_t flags)
{
    share_mem_t *sharemem;
    int ret;

    if (!name)
        return -EINVAL;
    if (size > 0 && PageAlign(size) >= SHARE_MEM_MAX_SIZE)
        return -EINVAL;
    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindByName(name);
    if (sharemem)
    {
        if ((flags & IPC_CREATE) && (flags & IPC_EXCL))
            ret = -EEXIST;
        else
            ret = sharemem->id;
    }
    else if (flags & IPC_CREATE)
    {
        sharemem = ShareMemAlloc(name, size);
        ret = sharemem ? (int)sharemem->id : -ENOMEM;
    }
    else
    {
        ret = -ENOENT;
    }
    SemaphoreUp(&share_mem_mutex);
    return ret;
}

int SysShareMemGet(char *name, uint32_t size, uint32_t flags)
{
    if (!name)
        return -EINVAL;
    if (SafetyCheckRange(name, SHARE_MEM_NAME_LEN) < 0)
        return -EINVAL;
    return ShareMemGet(name, size, flags);
}

//release share mem object id if no mapping still references it
//return 0 when the object exists (whether or not it was freed), -1 otherwise
int ShareMemPut(int id)
{
    share_mem_t *sharemem;
    int ret = -1;

    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindById(id);
    //act only on an existing object: an inverted test here would dereference
    //NULL for unknown ids and never free real objects
    if (sharemem)
    {
        //free the slot and its pages only once the ref count has dropped to 0
        if (AtomicGet(&sharemem->ref) <= 0)
            ShareMemFree(sharemem);
        ret = 0;
    }
    SemaphoreUp(&share_mem_mutex);
    return ret;
}

//take (down) the per-object semaphore of share mem object id
//flags: IPC_NOWAIT fails immediately instead of blocking
//return 0 once the semaphore is taken, -1 if the object does not exist or
//IPC_NOWAIT was given and the semaphore was not available
//NOTE(review): assumes SemaphoreTryDown takes the semaphore when it returns
//>= 0 — confirm against os/semaphore
int ShareMemDown(int id, int flags)
{
    share_mem_t *sharemem;

    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindById(id);
    SemaphoreUp(&share_mem_mutex);
    if (!sharemem)
    {
        return -1;
    }
    if (flags & IPC_NOWAIT)
    {
        //a successful try-down already holds the semaphore: do not down again
        if (SemaphoreTryDown(&sharemem->mutex) < 0)
            return -1;
    }
    else
    {
        SemaphoreDown(&sharemem->mutex);
    }
    return 0;
}

int ShareMemUp(int id)
{
    share_mem_t *sharemem;
    SemaphoreDown(&share_mem_mutex);
    sharemem = ShareMemFindById(id);
    SemaphoreUp(&share_mem_mutex);
    if (!sharemem)
        return -1;
    SemaphoreUp(&sharemem->mutex);
    return 0;
}

//system call entry: forwards to ShareMemPut and returns its result
int SysShareMemPut(int id)
{
    int ret = ShareMemPut(id);
    return ret;
}

//system call entry: forwards to ShareMemDown and returns its result
int SysShareMemDown(int id, int flags)
{
    int ret = ShareMemDown(id, flags);
    return ret;
}

//system call entry: forwards to ShareMemUp and returns its result
int SysShareMemUp(int id)
{
    int ret = ShareMemUp(id);
    return ret;
}

//system call entry: forwards to ShareMemMap and returns the mapped address
//(NULL on failure)
void *SysShareMemMap(int shmid, void *shmaddr, int shmflag)
{
    void *mapped = ShareMemMap(shmid, shmaddr, shmflag);
    return mapped;
}

//system call entry: forwards to ShareMemUnmap and returns its result
int SysShareMemUnmap(void *shmaddr, int shmflag)
{
    int ret = ShareMemUnmap(shmaddr, shmflag);
    return ret;
}